# 实现滑动窗口你和多项式

import numpy as np
import matplotlib.image as mpimg
import matplotlib.pyplot as plt
import cv2

# 导入图片
binary_warped = mpimg.imread('warped-example.jpg')

def find_lane_pixels(binary_warped):
    
    #--------------------------拆分两条线的直方图---------------------------
    # 获取图像下半个部分得图象
    histogram = np.sum(binary_warped[binary_warped.shape[0]//2:,:], axis=0)
    
    # 创建一个 out_img 用于图形输出并可视化结果
    out_img = np.dstack((binary_warped, binary_warped, binary_warped))
    
    # 找到左右半区最大的值
    # 这是左右线得起点
    midpoint = np.int(histogram.shape[0]//2)
    leftx_base = np.argmax(histogram[:midpoint])
    rightx_base = np.argmax(histogram[midpoint:]) + midpoint

    #--------------------------设置窗口和窗口超参数---------------------------
    # 超参数
    # 选择滑动窗口得数量
    nwindows = 9

    # 设置窗口得宽度范围边距
    margin = 100

    # 找到一个最小的像素数设置为较近的窗口
    minpix = 50

    # 设置窗口得高度
    window_height = np.int(binary_warped.shape[0]//nwindows)
    
    # 识别图像中所有已激活像素得x和y位置
    nonzero = binary_warped.nonzero()
    nonzeroy = np.array(nonzero[0])
    nonzerox = np.array(nonzero[1])
    
    # 为每个窗口更新位置
    leftx_current = leftx_base
    rightx_current = rightx_base

    # 创建两个空列表解说左右车道像素索引
    left_lane_inds = []
    right_lane_inds = []

    #--------------------------迭代nwindows跟踪曲率---------------------------
    # 遍历所有得窗口 跟踪曲率
    for window in range(nwindows):
        # 识别窗口得边界
        win_y_low = binary_warped.shape[0] - (window + 1) * window_height
        win_y_high = binary_warped.shape[0] - window * window_height
        win_xleft_low = leftx_current - margin
        win_xleft_high = leftx_current + margin
        win_xright_low = rightx_current - margin
        win_xright_high = rightx_current + margin
        
        # 在可视化图像上绘制窗口
        cv2.rectangle(out_img,(win_xleft_low,win_y_low),
        (win_xleft_high,win_y_high),(0,255,0), 2) 
        cv2.rectangle(out_img,(win_xright_low,win_y_low),
        (win_xright_high,win_y_high),(0,255,0), 2) 
        
        # 识别窗口内的 x 和 y 中的非零像素
        good_left_inds = ((nonzeroy >= win_y_low) & (nonzeroy < win_y_high) & 
        (nonzerox >= win_xleft_low) &  (nonzerox < win_xleft_high)).nonzero()[0]
        good_right_inds = ((nonzeroy >= win_y_low) & (nonzeroy < win_y_high) & 
        (nonzerox >= win_xright_low) &  (nonzerox < win_xright_high)).nonzero()[0]
        
        # 将上一步获得得非零像素得索引添加到列表中
        left_lane_inds.append(good_left_inds)
        right_lane_inds.append(good_right_inds)
        
        # 如果 inds > mimpix 非零像素大于最小像素 
        # 那么就将窗口重新定位，放在他们的平均位置上
        if len(good_left_inds) > minpix:
            leftx_current = np.int(np.mean(nonzerox[good_left_inds]))
        if len(good_right_inds) > minpix:        
            rightx_current = np.int(np.mean(nonzerox[good_right_inds]))

    #--------------------------拟合多项式---------------------------
    # 链接索引数组 
    try:
        left_lane_inds = np.concatenate(left_lane_inds)
        right_lane_inds = np.concatenate(right_lane_inds)
    except ValueError:
        pass

    # 提取左右线像素位置
    leftx = nonzerox[left_lane_inds]
    lefty = nonzeroy[left_lane_inds] 
    rightx = nonzerox[right_lane_inds]
    righty = nonzeroy[right_lane_inds]

    return leftx, lefty, rightx, righty, out_img


def fit_polynomial(binary_warped):
    # 找到车道像素
    leftx, lefty, rightx, righty, out_img = find_lane_pixels(binary_warped)

    # np.polyfit 将二阶多项式拟合到每个多项式
    left_fit = np.polyfit(lefty, leftx, 2)
    right_fit = np.polyfit(righty, rightx, 2)
    print(left_fit)
    print(right_fit)

    # 生成用于绘图得 x 和 y值
    ploty = np.linspace(0, binary_warped.shape[0]-1, binary_warped.shape[0] )
    try:
        left_fitx = left_fit[0]*ploty**2 + left_fit[1]*ploty + left_fit[2]
        right_fitx = right_fit[0]*ploty**2 + right_fit[1]*ploty + right_fit[2]
    except TypeError:
        print('没有找到一条线')
        left_fitx = 1*ploty**2 + 1*ploty
        right_fitx = 1*ploty**2 + 1*ploty

    #--------------------------可视化---------------------------
    # 左右车道区域的颜色
    out_img[lefty, leftx] = [255, 0, 0]
    out_img[righty, rightx] = [0, 0, 255]

    # 在车道线上绘制左右多项式
    plt.plot(left_fitx, ploty, color='yellow')
    plt.plot(right_fitx, ploty, color='yellow')

    return out_img


out_img = fit_polynomial(binary_warped)

plt.imshow(out_img)
plt.show()
cv2.imwrite('output-warped-example.jpg', out_img)